function results = SyntheticControl(Y1, Y0, X1raw, X0raw, PTS, vin, win)
% results = SyntheticControl(Y1, Y0, X1raw, X0raw, PTS, vin, win)
%
% Builds a synthetic control for a treated unit as a weighted combination
% of control units. Uses zscore, fmincon and quadprog.
%
% Inputs:
% Y1    :  T x 1 vector, outcome data for the treated unit
% Y0    :  T x n matrix, outcome data for the control group
% X1raw :  K x 1 vector, predictors for the treated unit (should be pretreatment data)
% X0raw :  K x n matrix, predictors for the control group (should be pretreatment data)
% PTS   :  Index vector (or logical mask of length T) such that Y1(PTS) and
%          Y0(PTS, :) are the pretreatment outcome data
%
% vin : Optional K x 1 vector with the weight each predictor gets in the
%       minimization problem. It is rescaled so that its first element is 1,
%       hence vin(1) must be nonzero. If omitted or empty, v is chosen so that
%       the sum of squared differences between Y1 and Y0*w is minimized on
%       the pretreatment sample.
%
% win : Optional n x 1 vector with the weight of each control unit (rescaled
%       to sum to one). If omitted or empty, w is determined such that
%       (X1 - X0*w)'*V*(X1 - X0*w) is minimized, with X1 and X0 the
%       normalized predictors and V = diag(v).
%
% Output: struct "results" with fields
% w      : n x 1 vector with the weights allocated to each control unit (sums to one)
% MSE    : scalar with the pretreatment sample mean squared error of Y0*w
% Ys     : T x 1 vector with the outcome data for the synthetic control, Ys = Y0*w
% Xs     : K x 1 vector with the raw predictors of the synthetic control, Xs = X0raw*w
% Xnorms : K x 1 vector with the normalized predictors of the synthetic control
% Xnorm  : K x 1 vector with the normalized predictors of the treated unit
% v      : K x 1 vector of predictor weights, rescaled to sum to one


% Basic error check
[T, check] = size(Y1);
if check > 1
    error('Y1 must be a Tx1 vector');
end
[K, check] = size(X1raw);
if check > 1
    error('X1 must be a K x 1 vector');
end
[TY0, n] = size(Y0);
if TY0 ~= T
    error('Y0 must be have the same number of rows as Y1');
end
[KX0, nX0] = size(X0raw);
if KX0 ~= K
    error('X0 must be have the same number of rows as X1');
end
if nX0 ~= n
    error('X0 must be have the same number of columns as Y0');
end

optimizeV = true;
if nargin >= 6 && ~isempty(vin)
    [Kv, check] = size(vin);
    if Kv ~= K || check ~= 1
        error('v must be a K x 1 vector');
    end
    if vin(1) == 0
        % Dividing by a zero first element would fill v with Inf/NaN
        error('v(1) must be nonzero because v is normalized by its first element');
    end

    optimizeV = false;
    % Internally only elements 2..K are carried around; the first one is
    % fixed at 1. Two-subscript indexing keeps a column shape even if K == 1.
    v = vin(2:end, 1)/vin(1);
end

optimizeW = true;
if nargin >= 7 && ~isempty(win)
    [nw, check] = size(win);
    if nw ~= n || check ~= 1
        error('w must be a n x 1 vector');
    end

    optimizeW = false;
    w = win/sum(win);
end

% Data normalization (cross-sectionally): each predictor is standardized
% across all units, controls and treated unit together
W = zscore([X0raw X1raw]')';
X0 = W(:, 1:n);
X1 = W(:, end);

% Pretreatment outcome data
Y0pre = Y0(PTS, :);
Y1pre = Y1(PTS, :);

if optimizeV
    if K > 1
        lowerBound = zeros(K-1, 1);
        opsV = optimset('Algorithm', 'active-set', 'Display', 'final', 'MaxFunEvals', 1000);
        attemptMsgs = {'SyntheticControl: Optimizing V (... be patient!)', ...
                       'SyntheticControl: Optimizing V (trying again... be patient!)', ...
                       'SyntheticControl: Optimizing V (still trying... be patient!)'};
        % FMINCON(FUN,X0,A,B,Aeq,Beq,LB,UB,NONLCON,OPTIONS,extra args for FUN)
        % An exit flag of 0 means the evaluation limit was hit, so the
        % search is restarted from the last iterate (at most three runs).
        vopt = ones(K-1, 1);
        for attempt = 1:numel(attemptMsgs)
            disp(attemptMsgs{attempt})
            [vopt, ~, exitflag] = fmincon(@SSRfun, vopt, [], [], [], [], lowerBound, [], [], opsV, X1, X0, Y1pre, Y0pre);
            if exitflag ~= 0
                break
            end
        end
        if exitflag == 0
            warning('SyntheticControl:maxFunEvals', ...
                'Optimization of V stopped at the evaluation limit; v may not be optimal.');
        end
        v = vopt;
        disp('... done!')
    else
        % With a single predictor its weight is fixed at 1: nothing to optimize
        v = zeros(0, 1);
    end
end

% Now recover W-weights
if optimizeW
    [~, w] = SSRfun(v, X1, X0, Y1pre, Y0pre);
end

% Normalize once and use the same weights for every reported quantity, so
% that results.Ys == Y0*results.w holds exactly
w = w/sum(w);
SSR = SSRonly(Y1pre, Y0pre, w);

results.w   = w;
% Divide by the number of pretreatment rows actually selected; length(PTS)
% would be wrong when PTS is a logical mask
results.MSE = SSR/size(Y1pre, 1);
results.Ys = Y0*w;
results.Xs = X0raw*w;
results.Xnorms = X0*w;
results.Xnorm  = X1;
results.v  = [1;v]/(sum(v) + 1);

function [SSR, weights] = SSRfun(vv, X1, X0, Z1, Z0)
% SSRfun  Objective used when searching over the predictor weights.
%
% For the predictor weights [1; vv], solves a quadratic program for the
% control-unit weights (constrained to sum to one, each between 0 and 1)
% and returns the sum of squared differences between Z1 and Z0*weights.
%
% Inputs:
% vv : weights for predictors 2..K (the first predictor's weight is fixed at 1)
% X1 : predictors of the treated unit
% X0 : predictors of the control units (one column per unit)
% Z1 : outcomes of the treated unit
% Z0 : outcomes of the control units (one column per unit)
%
% Outputs:
% SSR     : sum of squared residuals Z1 - Z0*weights
% weights : control-unit weights returned by the quadratic program

nUnits = size(Z0, 2);

% Diagonal weighting matrix with the first predictor's weight pinned at 1
Vmat = diag([1; vv]);

% Quadratic and linear terms of the program; the quadratic term is
% symmetrized to remove round-off asymmetry
quadTerm = X0'*Vmat*X0;
quadTerm = (quadTerm + quadTerm')/2;
linTerm  = - X1'*Vmat*X0;

solverOpts = optimset('LargeScale', 'Off', 'Display', 'Off');
startW   = ones(nUnits,1)/nUnits;   % equal weights as the starting point
sumToOne = ones(1,nUnits);          % equality constraint: weights add up to 1
lowerB   = zeros(nUnits,1);
upperB   = ones(nUnits,1);

weights = quadprog(quadTerm, linTerm, [], [], sumToOne, 1, lowerB, upperB, startW, solverOpts);

% Absolute value makes every returned weight nonnegative (presumably a guard
% against tiny negative values from the solver)
weights = abs(weights);

SSR = SSRonly(Z1, Z0, weights);

function ssr = SSRonly(Z1, Z0, weights)
% SSRonly  Sum of squared residuals between Z1 and the combination Z0*weights.
ssr = sum((Z1 - Z0*weights).^2);